-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[Hexagon] Enable soft bf16 in hexagon #167922
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This patch adds: 1. Support for recognizing the bf16 type in the frontend, plus isel/ABI support for scalar bf16 programs. Limitation: fp_to_bf16 is generated with a TableGen pattern instead of being lowered via expansion, because we do not yet have support for the fcanonicalize instruction, which is needed to prevent an SNaN from being converted to an infinity by truncation. 2. Vector codegen support for bf16. Patch By: Fateme Hosseini. Co-authored-by: Muntasir Mallick <mallick@qti.qualcomm.com> Co-authored-by: Kaushik Kulkarni <quic_kauskulk@quicinc.com> Change-Id: I767145458dafcaf7691eb9ab4e03d33e5fd03a6a
|
@llvm/pr-subscribers-backend-hexagon @llvm/pr-subscribers-clang Author: Fateme Hosseini (fhossein-quic). Changes: This patch adds:
Patch By: Fateme Hosseini. Patch is 87.21 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/167922.diff 11 Files Affected:
diff --git a/clang/lib/Basic/Targets/Hexagon.cpp b/clang/lib/Basic/Targets/Hexagon.cpp
index d5b413cb58eb2..bd70bfe4fef51 100644
--- a/clang/lib/Basic/Targets/Hexagon.cpp
+++ b/clang/lib/Basic/Targets/Hexagon.cpp
@@ -155,9 +155,14 @@ bool HexagonTargetInfo::handleTargetFeatures(std::vector<std::string> &Features,
HasFastHalfType = true;
HasFloat16 = true;
}
+ if (CPU.compare("hexagonv81") >= 0)
+ HasBFloat16 = true;
+
return true;
}
+bool HexagonTargetInfo::hasBFloat16Type() const { return HasBFloat16; }
+
const char *const HexagonTargetInfo::GCCRegNames[] = {
// Scalar registers:
"r0", "r1", "r2", "r3", "r4", "r5", "r6", "r7", "r8", "r9", "r10", "r11",
diff --git a/clang/lib/Basic/Targets/Hexagon.h b/clang/lib/Basic/Targets/Hexagon.h
index a65663ca09eee..53c348a3246f9 100644
--- a/clang/lib/Basic/Targets/Hexagon.h
+++ b/clang/lib/Basic/Targets/Hexagon.h
@@ -64,6 +64,8 @@ class LLVM_LIBRARY_VISIBILITY HexagonTargetInfo : public TargetInfo {
// for modeling predicate registers in HVX, and the bool -> byte
// correspondence matches the HVX architecture.
BoolWidth = BoolAlign = 8;
+ BFloat16Width = BFloat16Align = 16;
+ BFloat16Format = &llvm::APFloat::BFloat();
}
llvm::SmallVector<Builtin::InfosShard> getTargetBuiltins() const override;
@@ -95,6 +97,8 @@ class LLVM_LIBRARY_VISIBILITY HexagonTargetInfo : public TargetInfo {
bool hasFeature(StringRef Feature) const override;
+ bool hasBFloat16Type() const override;
+
bool
initFeatureMap(llvm::StringMap<bool> &Features, DiagnosticsEngine &Diags,
StringRef CPU,
diff --git a/llvm/lib/Target/Hexagon/HexagonCallingConv.td b/llvm/lib/Target/Hexagon/HexagonCallingConv.td
index dceb70c8abbf2..80adde8246809 100644
--- a/llvm/lib/Target/Hexagon/HexagonCallingConv.td
+++ b/llvm/lib/Target/Hexagon/HexagonCallingConv.td
@@ -25,6 +25,8 @@ def CC_HexagonStack: CallingConv<[
def CC_Hexagon_Legacy: CallingConv<[
CCIfType<[i1,i8,i16],
CCPromoteToType<i32>>,
+ CCIfType<[bf16],
+ CCBitConvertToType<i32>>,
CCIfType<[f32],
CCBitConvertToType<i32>>,
CCIfType<[f64],
@@ -55,6 +57,8 @@ def CC_Hexagon_Legacy: CallingConv<[
def CC_Hexagon: CallingConv<[
CCIfType<[i1,i8,i16],
CCPromoteToType<i32>>,
+ CCIfType<[bf16],
+ CCBitConvertToType<i32>>,
CCIfType<[f32],
CCBitConvertToType<i32>>,
CCIfType<[f64],
@@ -88,6 +92,8 @@ def CC_Hexagon: CallingConv<[
def RetCC_Hexagon: CallingConv<[
CCIfType<[i1,i8,i16],
CCPromoteToType<i32>>,
+ CCIfType<[bf16],
+ CCBitConvertToType<i32>>,
CCIfType<[f32],
CCBitConvertToType<i32>>,
CCIfType<[f64],
@@ -149,16 +155,16 @@ def CC_Hexagon_HVX: CallingConv<[
CCIfType<[v128i1], CCPromoteToType<v128i8>>>,
CCIfHvx128<
- CCIfType<[v32i32,v64i16,v128i8,v32f32,v64f16],
+ CCIfType<[v32i32,v64i16,v128i8,v32f32,v64f16,v64bf16],
CCAssignToReg<[V0,V1,V2,V3,V4,V5,V6,V7,V8,V9,V10,V11,V12,V13,V14,V15]>>>,
CCIfHvx128<
- CCIfType<[v64i32,v128i16,v256i8,v64f32,v128f16],
+ CCIfType<[v64i32,v128i16,v256i8,v64f32,v128f16,v128bf16],
CCAssignToReg<[W0,W1,W2,W3,W4,W5,W6,W7]>>>,
CCIfHvx128<
- CCIfType<[v32i32,v64i16,v128i8,v32f32,v64f16],
+ CCIfType<[v32i32,v64i16,v128i8,v32f32,v64f16,v64bf16],
CCAssignToStack<128,128>>>,
CCIfHvx128<
- CCIfType<[v64i32,v128i16,v256i8,v64f32,v128f16],
+ CCIfType<[v64i32,v128i16,v256i8,v64f32,v128f16,v128bf16], // same types as the W-register rule; v64bf16 is already handled by the 128-byte stack rule
CCAssignToStack<256,128>>>,
CCDelegateTo<CC_Hexagon>
@@ -175,10 +181,10 @@ def RetCC_Hexagon_HVX: CallingConv<[
// HVX 128-byte mode
CCIfHvx128<
- CCIfType<[v32i32,v64i16,v128i8,v32f32,v64f16],
+ CCIfType<[v32i32,v64i16,v128i8,v32f32,v64f16,v64bf16],
CCAssignToReg<[V0]>>>,
CCIfHvx128<
- CCIfType<[v64i32,v128i16,v256i8,v64f32,v128f16],
+ CCIfType<[v64i32,v128i16,v256i8,v64f32,v128f16,v128bf16],
CCAssignToReg<[W0]>>>,
CCDelegateTo<RetCC_Hexagon>
diff --git a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
index 04a97606cb7f8..372524a6ac88a 100644
--- a/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonISelLowering.cpp
@@ -1677,6 +1677,8 @@ HexagonTargetLowering::HexagonTargetLowering(const TargetMachine &TM,
}
// Turn FP truncstore into trunc + store.
setTruncStoreAction(MVT::f64, MVT::f32, Expand);
+ setTruncStoreAction(MVT::f32, MVT::bf16, Expand);
+ setTruncStoreAction(MVT::f64, MVT::bf16, Expand);
// Turn FP extload into load/fpextend.
for (MVT VT : MVT::fp_valuetypes())
setLoadExtAction(ISD::EXTLOAD, VT, MVT::f32, Expand);
@@ -1872,9 +1874,15 @@ HexagonTargetLowering::HexagonTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FP16_TO_FP, MVT::f64, Expand);
setOperationAction(ISD::FP_TO_FP16, MVT::f32, Expand);
setOperationAction(ISD::FP_TO_FP16, MVT::f64, Expand);
+ setOperationAction(ISD::BF16_TO_FP, MVT::f32, Expand);
+ setOperationAction(ISD::BF16_TO_FP, MVT::f64, Expand);
+ setOperationAction(ISD::FP_TO_BF16, MVT::f64, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
+ setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand);
+ setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand);
+
setTruncStoreAction(MVT::f32, MVT::f16, Expand);
setTruncStoreAction(MVT::f64, MVT::f16, Expand);
diff --git a/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp b/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp
index 0573f64084d6f..bcc705c3fd99e 100644
--- a/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp
@@ -88,6 +88,10 @@ HexagonTargetLowering::initializeHVXLowering() {
addRegisterClass(MVT::v64f32, &Hexagon::HvxWRRegClass);
addRegisterClass(MVT::v128f16, &Hexagon::HvxWRRegClass);
}
+ if (Subtarget.useHVXV81Ops()) {
+ addRegisterClass(MVT::v64bf16, &Hexagon::HvxVRRegClass);
+ addRegisterClass(MVT::v128bf16, &Hexagon::HvxWRRegClass);
+ }
}
// Set up operation actions.
@@ -162,6 +166,30 @@ HexagonTargetLowering::initializeHVXLowering() {
setPromoteTo(ISD::VECTOR_SHUFFLE, MVT::v64f32, ByteW);
setPromoteTo(ISD::VECTOR_SHUFFLE, MVT::v32f32, ByteV);
+ if (Subtarget.useHVXV81Ops()) {
+ setPromoteTo(ISD::VECTOR_SHUFFLE, MVT::v128bf16, ByteW);
+ setPromoteTo(ISD::VECTOR_SHUFFLE, MVT::v64bf16, ByteV);
+ setPromoteTo(ISD::SETCC, MVT::v64bf16, MVT::v64f32);
+ setPromoteTo(ISD::FADD, MVT::v64bf16, MVT::v64f32);
+ setPromoteTo(ISD::FSUB, MVT::v64bf16, MVT::v64f32);
+ setPromoteTo(ISD::FMUL, MVT::v64bf16, MVT::v64f32);
+ setPromoteTo(ISD::FMINNUM, MVT::v64bf16, MVT::v64f32);
+ setPromoteTo(ISD::FMAXNUM, MVT::v64bf16, MVT::v64f32);
+
+ setOperationAction(ISD::SPLAT_VECTOR, MVT::v64bf16, Legal);
+ setOperationAction(ISD::INSERT_SUBVECTOR, MVT::v64bf16, Custom);
+ setOperationAction(ISD::EXTRACT_SUBVECTOR, MVT::v64bf16, Custom);
+
+ setOperationAction(ISD::MLOAD, MVT::v64bf16, Custom);
+ setOperationAction(ISD::MSTORE, MVT::v64bf16, Custom);
+ setOperationAction(ISD::BUILD_VECTOR, MVT::v64bf16, Custom);
+ setOperationAction(ISD::CONCAT_VECTORS, MVT::v64bf16, Custom);
+
+ setOperationAction(ISD::SPLAT_VECTOR, MVT::bf16, Custom);
+ setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::bf16, Custom);
+ setOperationAction(ISD::BUILD_VECTOR, MVT::bf16, Custom);
+ }
+
for (MVT P : FloatW) {
setOperationAction(ISD::LOAD, P, Custom);
setOperationAction(ISD::STORE, P, Custom);
@@ -1667,14 +1695,15 @@ HexagonTargetLowering::LowerHvxBuildVector(SDValue Op, SelectionDAG &DAG)
// In case of MVT::f16 BUILD_VECTOR, since MVT::f16 is
// not a legal type, just bitcast the node to use i16
// types and bitcast the result back to f16
- if (VecTy.getVectorElementType() == MVT::f16) {
- SmallVector<SDValue,64> NewOps;
+ if (VecTy.getVectorElementType() == MVT::f16 ||
+ VecTy.getVectorElementType() == MVT::bf16) {
+ SmallVector<SDValue, 64> NewOps;
for (unsigned i = 0; i != Size; i++)
NewOps.push_back(DAG.getBitcast(MVT::i16, Ops[i]));
- SDValue T0 = DAG.getNode(ISD::BUILD_VECTOR, dl,
- tyVector(VecTy, MVT::i16), NewOps);
- return DAG.getBitcast(tyVector(VecTy, MVT::f16), T0);
+ SDValue T0 =
+ DAG.getNode(ISD::BUILD_VECTOR, dl, tyVector(VecTy, MVT::i16), NewOps);
+ return DAG.getBitcast(tyVector(VecTy, VecTy.getVectorElementType()), T0);
}
// First, split the BUILD_VECTOR for vector pairs. We could generate
@@ -1698,7 +1727,7 @@ HexagonTargetLowering::LowerHvxSplatVector(SDValue Op, SelectionDAG &DAG)
MVT VecTy = ty(Op);
MVT ArgTy = ty(Op.getOperand(0));
- if (ArgTy == MVT::f16) {
+ if (ArgTy == MVT::f16 || ArgTy == MVT::bf16) {
MVT SplatTy = MVT::getVectorVT(MVT::i16, VecTy.getVectorNumElements());
SDValue ToInt16 = DAG.getBitcast(MVT::i16, Op.getOperand(0));
SDValue ToInt32 = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32, ToInt16);
@@ -1831,12 +1860,12 @@ HexagonTargetLowering::LowerHvxInsertElement(SDValue Op, SelectionDAG &DAG)
if (ElemTy == MVT::i1)
return insertHvxElementPred(VecV, IdxV, ValV, dl, DAG);
- if (ElemTy == MVT::f16) {
+ if (ElemTy == MVT::f16 || ElemTy == MVT::bf16) {
SDValue T0 = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl,
tyVector(VecTy, MVT::i16),
DAG.getBitcast(tyVector(VecTy, MVT::i16), VecV),
DAG.getBitcast(MVT::i16, ValV), IdxV);
- return DAG.getBitcast(tyVector(VecTy, MVT::f16), T0);
+ return DAG.getBitcast(tyVector(VecTy, ElemTy), T0);
}
return insertHvxElementReg(VecV, IdxV, ValV, dl, DAG);
@@ -2334,6 +2363,20 @@ SDValue HexagonTargetLowering::LowerHvxFpExtend(SDValue Op,
MVT VecTy = ty(Op);
MVT ArgTy = ty(Op.getOperand(0));
const SDLoc &dl(Op);
+
+ if (ArgTy == MVT::v64bf16) {
+ MVT HalfTy = typeSplit(VecTy).first;
+ SDValue BF16Vec = Op.getOperand(0);
+ SDValue Zeroes = getInstr(Hexagon::V6_vxor, dl, HalfTy, {BF16Vec, BF16Vec}, DAG);
+ // Interleave zero vector with the bf16 vector, with zeroes in the lower half
+ // of each 32 bit lane, effectively extending the bf16 values to fp32 values.
+ SDValue ShuffVec = getInstr(Hexagon::V6_vshufoeh, dl, VecTy, {BF16Vec, Zeroes}, DAG);
+ VectorPair VecPair = opSplit(ShuffVec, dl, DAG);
+ SDValue Result = getInstr(Hexagon::V6_vshuffvdd, dl, VecTy,
+ {VecPair.second, VecPair.first, DAG.getSignedConstant(-4, dl, MVT::i32)}, DAG);
+ return Result;
+ }
+
assert(VecTy == MVT::v64f32 && ArgTy == MVT::v64f16);
SDValue F16Vec = Op.getOperand(0);
diff --git a/llvm/lib/Target/Hexagon/HexagonPatterns.td b/llvm/lib/Target/Hexagon/HexagonPatterns.td
index e40dbd251b5b7..e84070f1a5468 100644
--- a/llvm/lib/Target/Hexagon/HexagonPatterns.td
+++ b/llvm/lib/Target/Hexagon/HexagonPatterns.td
@@ -391,7 +391,6 @@ def Fptoui: pf1<fp_to_uint>;
def Sitofp: pf1<sint_to_fp>;
def Uitofp: pf1<uint_to_fp>;
-
// --(1) Immediate -------------------------------------------------------
//
@@ -474,6 +473,18 @@ def: OpR_R_pat<F2_conv_df2uw_chop, pf1<fp_to_uint>, i32, F64>;
def: OpR_R_pat<F2_conv_sf2ud_chop, pf1<fp_to_uint>, i64, F32>;
def: OpR_R_pat<F2_conv_df2ud_chop, pf1<fp_to_uint>, i64, F64>;
+def: Pat<(i32 (fp_to_bf16 F32:$v)),
+ (C2_mux (F2_sfclass F32:$v, 0x10), (A2_tfrsi(i32 0x7fff)),
+ (C2_mux
+ (C2_cmpeq
+ (A2_and F32:$v, (A2_tfrsi (i32 0x1FFFF))),
+ (A2_tfrsi (i32 0x08000))),
+ (A2_and (A2_asrh F32:$v), (A2_tfrsi (i32 65535))),
+ (A2_and
+ (A2_asrh
+ (A2_add F32:$v, (A2_and F32:$v, (A2_tfrsi (i32 0x8000))))),
+ (A2_tfrsi (i32 65535))))
+ )>;
// Bitcast is different than [fp|sint|uint]_to_[sint|uint|fp].
def: Pat<(i32 (bitconvert F32:$v)), (I32:$v)>;
def: Pat<(f32 (bitconvert I32:$v)), (F32:$v)>;
diff --git a/llvm/lib/Target/Hexagon/HexagonPatternsHVX.td b/llvm/lib/Target/Hexagon/HexagonPatternsHVX.td
index d19920cfc9ea0..4cb29e7f00317 100644
--- a/llvm/lib/Target/Hexagon/HexagonPatternsHVX.td
+++ b/llvm/lib/Target/Hexagon/HexagonPatternsHVX.td
@@ -15,12 +15,14 @@ def HVI16: PatLeaf<(VecI16 HvxVR:$R)>;
def HVI32: PatLeaf<(VecI32 HvxVR:$R)>;
def HVF16: PatLeaf<(VecF16 HvxVR:$R)>;
def HVF32: PatLeaf<(VecF32 HvxVR:$R)>;
+def HVBF16: PatLeaf<(VecBF16 HvxVR:$R)>;
def HWI8: PatLeaf<(VecPI8 HvxWR:$R)>;
def HWI16: PatLeaf<(VecPI16 HvxWR:$R)>;
def HWI32: PatLeaf<(VecPI32 HvxWR:$R)>;
def HWF16: PatLeaf<(VecPF16 HvxWR:$R)>;
def HWF32: PatLeaf<(VecPF32 HvxWR:$R)>;
+def HWBF16: PatLeaf<(VecPBF16 HvxWR:$R)>; // HvxWR operands use the VecP* types, like the other HW* leaves
def SDTVecUnaryOp:
SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>;
@@ -182,12 +184,15 @@ let Predicates = [UseHVX] in {
}
let Predicates = [UseHVXV68] in {
- defm: HvxLda_pat<V6_vL32b_nt_ai, alignednontemporalload, VecF16, IsVecOff>;
- defm: HvxLda_pat<V6_vL32b_nt_ai, alignednontemporalload, VecF32, IsVecOff>;
- defm: HvxLda_pat<V6_vL32b_ai, alignedload, VecF16, IsVecOff>;
- defm: HvxLda_pat<V6_vL32b_ai, alignedload, VecF32, IsVecOff>;
- defm: HvxLd_pat<V6_vL32Ub_ai, unalignedload, VecF16, IsVecOff>;
- defm: HvxLd_pat<V6_vL32Ub_ai, unalignedload, VecF32, IsVecOff>;
+ defm : HvxLda_pat<V6_vL32b_nt_ai, alignednontemporalload, VecBF16, IsVecOff>;
+ defm : HvxLda_pat<V6_vL32b_nt_ai, alignednontemporalload, VecF16, IsVecOff>;
+ defm : HvxLda_pat<V6_vL32b_nt_ai, alignednontemporalload, VecF32, IsVecOff>;
+ defm : HvxLda_pat<V6_vL32b_ai, alignedload, VecBF16, IsVecOff>;
+ defm : HvxLda_pat<V6_vL32b_ai, alignedload, VecF16, IsVecOff>;
+ defm : HvxLda_pat<V6_vL32b_ai, alignedload, VecF32, IsVecOff>;
+ defm : HvxLd_pat<V6_vL32Ub_ai, unalignedload, VecBF16, IsVecOff>;
+ defm : HvxLd_pat<V6_vL32Ub_ai, unalignedload, VecF16, IsVecOff>;
+ defm : HvxLd_pat<V6_vL32Ub_ai, unalignedload, VecF32, IsVecOff>;
}
// HVX stores
@@ -233,10 +238,13 @@ let Predicates = [UseHVX] in {
}
let Predicates = [UseHVXV68] in {
+ defm: HvxSt_pat<V6_vS32b_nt_ai, alignednontemporalstore, HVBF16, IsVecOff>;
defm: HvxSt_pat<V6_vS32b_nt_ai, alignednontemporalstore, HVF16, IsVecOff>;
defm: HvxSt_pat<V6_vS32b_nt_ai, alignednontemporalstore, HVF32, IsVecOff>;
+ defm: HvxSt_pat<V6_vS32b_ai, alignedstore, HVBF16, IsVecOff>;
defm: HvxSt_pat<V6_vS32b_ai, alignedstore, HVF16, IsVecOff>;
defm: HvxSt_pat<V6_vS32b_ai, alignedstore, HVF32, IsVecOff>;
+ defm: HvxSt_pat<V6_vS32Ub_ai, unalignedstore, HVBF16, IsVecOff>;
defm: HvxSt_pat<V6_vS32Ub_ai, unalignedstore, HVF16, IsVecOff>;
defm: HvxSt_pat<V6_vS32Ub_ai, unalignedstore, HVF32, IsVecOff>;
}
@@ -255,18 +263,24 @@ let Predicates = [UseHVX] in {
let Predicates = [UseHVX, UseHVXFloatingPoint] in {
defm: NopCast_pat<VecI8, VecF16, HvxVR>;
+ defm: NopCast_pat<VecI8, VecBF16, HvxVR>;
defm: NopCast_pat<VecI8, VecF32, HvxVR>;
defm: NopCast_pat<VecI16, VecF16, HvxVR>;
+ defm: NopCast_pat<VecI16, VecBF16, HvxVR>;
defm: NopCast_pat<VecI16, VecF32, HvxVR>;
defm: NopCast_pat<VecI32, VecF16, HvxVR>;
+ defm: NopCast_pat<VecI32, VecBF16, HvxVR>;
defm: NopCast_pat<VecI32, VecF32, HvxVR>;
defm: NopCast_pat<VecF16, VecF32, HvxVR>;
defm: NopCast_pat<VecPI8, VecPF16, HvxWR>;
+ defm: NopCast_pat<VecPI8, VecPBF16, HvxWR>;
defm: NopCast_pat<VecPI8, VecPF32, HvxWR>;
defm: NopCast_pat<VecPI16, VecPF16, HvxWR>;
+ defm: NopCast_pat<VecPI16, VecPBF16, HvxWR>;
defm: NopCast_pat<VecPI16, VecPF32, HvxWR>;
defm: NopCast_pat<VecPI32, VecPF16, HvxWR>;
+ defm: NopCast_pat<VecPI32, VecPBF16, HvxWR>;
defm: NopCast_pat<VecPI32, VecPF32, HvxWR>;
defm: NopCast_pat<VecPF16, VecPF32, HvxWR>;
}
@@ -315,11 +329,14 @@ let Predicates = [UseHVX] in {
let Predicates = [UseHVX, UseHVXFloatingPoint] in {
let AddedComplexity = 100 in {
def: Pat<(VecF16 vzero), (V6_vd0)>;
+ def: Pat<(VecBF16 vzero), (V6_vd0)>;
def: Pat<(VecF32 vzero), (V6_vd0)>;
def: Pat<(VecPF16 vzero), (PS_vdd0)>;
+ def: Pat<(VecPBF16 vzero), (PS_vdd0)>;
def: Pat<(VecPF32 vzero), (PS_vdd0)>;
def: Pat<(concat_vectors (VecF16 vzero), (VecF16 vzero)), (PS_vdd0)>;
+ def : Pat<(concat_vectors (VecBF16 vzero), (VecBF16 vzero)), (PS_vdd0)>;
def: Pat<(concat_vectors (VecF32 vzero), (VecF32 vzero)), (PS_vdd0)>;
}
@@ -355,11 +372,13 @@ let Predicates = [UseHVX] in {
let Predicates = [UseHVXV68, UseHVXFloatingPoint] in {
let AddedComplexity = 30 in {
def: Pat<(VecF16 (splat_vector u16_0ImmPred:$V)), (PS_vsplatih imm:$V)>;
+ def: Pat<(VecBF16 (splat_vector u16_0ImmPred:$V)), (PS_vsplatih imm:$V)>;
def: Pat<(VecF32 (splat_vector anyint:$V)), (PS_vsplatiw imm:$V)>;
def: Pat<(VecF32 (splat_vector f32ImmPred:$V)), (PS_vsplatiw (ftoi $V))>;
}
let AddedComplexity = 20 in {
def: Pat<(VecF16 (splat_vector I32:$Rs)), (PS_vsplatrh $Rs)>;
+ def: Pat<(VecBF16 (splat_vector I32:$Rs)), (PS_vsplatrh $Rs)>;
def: Pat<(VecF32 (splat_vector I32:$Rs)), (PS_vsplatrw $Rs)>;
def: Pat<(VecF32 (splat_vector F32:$Rs)), (PS_vsplatrw $Rs)>;
}
@@ -519,6 +538,35 @@ let Predicates = [UseHVXV68, UseHVXIEEEFP] in {
def: Pat<(VecPF16 (Uitofp HVI8:$Vu)), (V6_vcvt_hf_ub HvxVR:$Vu)>;
}
+let Predicates = [UseHVXV81] in {
+ def : Pat<(VecBF16 (pf1<fpround> HWF32:$Vuu)),
+ (V6_vpackwuh_sat (V6_vmux
+ (V6_veqsf (HiVec HvxWR:$Vuu), (HiVec HvxWR:$Vuu)),
+ (V6_vlsrw (V6_vmux (V6_veqw (V6_vand (HiVec HvxWR:$Vuu),
+ (PS_vsplatiw (i32 0x1FFFF))),
+ (PS_vsplatiw (i32 0x08000))),
+ (HiVec HvxWR:$Vuu),
+ (V6_vaddw (HiVec HvxWR:$Vuu),
+ (V6_vand (HiVec HvxWR:$Vuu),
+ (PS_vsplatiw (i32 0x8000))))),
+ (A2_tfrsi 16)),
+ (PS_vsplatih (i32 0x7fff))),
+ (V6_vmux (V6_veqsf (LoVec HvxWR:$Vuu), (LoVec HvxWR:$Vuu)),
+ (V6_vlsrw (V6_vmux (V6_veqw (V6_vand (LoVec HvxWR:$Vuu),
+ (PS_vsplatiw (i32 0x1FFFF))),
+ (PS_vsplatiw (i32 0x08000))),
+ (LoVec HvxWR:$Vuu),
+ (V6_vaddw (LoVec HvxWR:$Vuu),
+ (V6_vand (LoVec HvxWR:$Vuu),
+ (PS_vsplatiw (i32 0x8000))))),
+ (A2_tfrsi 16)),
+ (PS_vsplatih (i32 0x7fff))))>;
+}
+
+let Predicates = [UseHVXV73, UseHVXQFloat] in {
+ def : Pat<(VecF32 (Sitofp HVI32:$Vu)), (V6_vconv_sf_w HvxVR:$Vu)>;
+}
+
let Predicates = [UseHVXV68, UseHVXFloatingPoint] in {
def: Pat<(vselect HQ16:$Qu, HVF16:$Vs, HVF16:$Vt),
(V6_vmux HvxQR:$Qu, HvxVR:$Vs, HvxVR:$Vt)>;
@@ -531,6 +579,13 @@ let Predicates = [UseHVXV68, UseHVXFloatingPoint] in {
(V6_vmux HvxQR:$Qu, HvxVR:$Vt, HvxVR:$Vs)>;
}
+let Predicates = [UseHVXV81, UseHVXFloatingPoint] in {
+ def : Pat<(vselect HQ16:$Qu, HVBF16:$Vs, HVBF16:$Vt),
+ (V6_vmux HvxQR:$Qu, HvxVR:$Vs, HvxVR:$Vt)>;
+ def : Pat<(vselect (qnot HQ16:$Qu), HVBF16:$Vs, HVBF16:$Vt),
+ (V6_vmux HvxQR:$Qu, HvxVR:$Vt, HvxVR:$Vs)>;
+}
+
let Predicates = [UseHVXV68, UseHVX128B, UseHVXQFloat] in {
let AddedComplexity = 220 in {
defm: MinMax_pats<V6_vmin_hf, V6_vmax_hf, vselect, setgt, VecQ16, HVF16>;
diff --git a/llvm/lib/Target/Hexagon/HexagonRegisterInfo.td b/llvm/lib/Target/Hexagon/HexagonRegisterInfo.td
index 3a77fcd04e35c..1f1aebd0e5ec9 100644
--- a/llvm/lib/Target/Hexagon/HexagonRegisterInfo.td
+++ b/llvm/lib/Target/Hexagon/HexagonRegisterInfo.td
@@ -15,141 +15,126 @@ let Namespace = "Hexagon" in {
class HexagonReg<bits<5> num, string n, list<string> alt = [],
list<Register> alias = []> : Register<n, alt> {
let Aliases = alias;
- let HWEncoding{4-0} = num;
+ let HWEncoding{4 -0} = num;
}
// These registers are used to preserve a distinction between
// vector register pairs of differing order.
- class ...
[truncated]
|
You can test this locally with the following command: git-clang-format --diff origin/main HEAD --extensions h,cpp -- clang/lib/Basic/Targets/Hexagon.cpp clang/lib/Basic/Targets/Hexagon.h llvm/lib/Target/Hexagon/HexagonISelLowering.cpp llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp llvm/lib/Target/Hexagon/HexagonSubtarget.h --diff_from_common_commit
View the diff from clang-format here.
diff --git a/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp b/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp
index bcc705c3f..63b7683d5 100644
--- a/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonISelLoweringHVX.cpp
@@ -2365,16 +2365,21 @@ SDValue HexagonTargetLowering::LowerHvxFpExtend(SDValue Op,
const SDLoc &dl(Op);
if (ArgTy == MVT::v64bf16) {
- MVT HalfTy = typeSplit(VecTy).first;
- SDValue BF16Vec = Op.getOperand(0);
- SDValue Zeroes = getInstr(Hexagon::V6_vxor, dl, HalfTy, {BF16Vec, BF16Vec}, DAG);
- // Interleave zero vector with the bf16 vector, with zeroes in the lower half
- // of each 32 bit lane, effectively extending the bf16 values to fp32 values.
- SDValue ShuffVec = getInstr(Hexagon::V6_vshufoeh, dl, VecTy, {BF16Vec, Zeroes}, DAG);
- VectorPair VecPair = opSplit(ShuffVec, dl, DAG);
- SDValue Result = getInstr(Hexagon::V6_vshuffvdd, dl, VecTy,
- {VecPair.second, VecPair.first, DAG.getSignedConstant(-4, dl, MVT::i32)}, DAG);
- return Result;
+ MVT HalfTy = typeSplit(VecTy).first;
+ SDValue BF16Vec = Op.getOperand(0);
+ SDValue Zeroes =
+ getInstr(Hexagon::V6_vxor, dl, HalfTy, {BF16Vec, BF16Vec}, DAG);
+ // Interleave zero vector with the bf16 vector, with zeroes in the lower
+ // half of each 32 bit lane, effectively extending the bf16 values to fp32
+ // values.
+ SDValue ShuffVec =
+ getInstr(Hexagon::V6_vshufoeh, dl, VecTy, {BF16Vec, Zeroes}, DAG);
+ VectorPair VecPair = opSplit(ShuffVec, dl, DAG);
+ SDValue Result = getInstr(Hexagon::V6_vshuffvdd, dl, VecTy,
+ {VecPair.second, VecPair.first,
+ DAG.getSignedConstant(-4, dl, MVT::i32)},
+ DAG);
+ return Result;
}
assert(VecTy == MVT::v64f32 && ArgTy == MVT::v64f16);
diff --git a/llvm/lib/Target/Hexagon/HexagonSubtarget.h b/llvm/lib/Target/Hexagon/HexagonSubtarget.h
index 5cf0995c6..b2fc07ee1 100644
--- a/llvm/lib/Target/Hexagon/HexagonSubtarget.h
+++ b/llvm/lib/Target/Hexagon/HexagonSubtarget.h
@@ -343,7 +343,8 @@ public:
ArrayRef<MVT> getHVXElementTypes() const {
static MVT Types[] = {MVT::i8, MVT::i16, MVT::i32};
- static MVT TypesV81[] = {MVT::i8, MVT::i16, MVT::i32, MVT::f16, MVT::bf16, MVT::f32};
+ static MVT TypesV81[] = {MVT::i8, MVT::i16, MVT::i32,
+ MVT::f16, MVT::bf16, MVT::f32};
if (useHVXV81Ops() && useHVXFloatingPoint())
return ArrayRef(TypesV81);
|
This patch adds:
Support to recognize bf16 type in the frontend and isel/abi support for scalar bf16 programs
Limitation: fp_to_bf16 is generated with a TableGen pattern instead of being lowered via expansion, because we do not yet have support for the fcanonicalize instruction, which is needed to prevent an SNaN from being converted to an infinity by truncation.
Vector codegen support for bf16
Patch By: Fateme Hosseini